{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "import codecs\n",
    "import re\n",
    "import nltk\n",
    "from pyspark import SparkConf, SparkContext\n",
    "from pyspark.sql import SparkSession, Row\n",
    "from pyspark.sql import functions as F\n",
    "from pyspark.sql.types import StringType, IntegerType\n",
    "from tqdm import tqdm\n",
    "from pyspark.ml.feature import HashingTF, IDF, Tokenizer\n",
    "from pyspark.ml.clustering import KMeans\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "from pyspark.ml.pipeline import Pipeline\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "from pyspark.ml.linalg import Vectors, VectorUDT\n",
    "from pyspark.sql.types import StructField, StructType"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#Create SparkConf\n",
    "sparkConf =SparkConf().setAppName('SentenceVectorExplorationV2').setMaster('local[*]')\n",
    "#Create SparkContext\n",
    "sc=SparkContext(conf=sparkConf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark = SparkSession(sc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.read.option(\"header\", \"true\").csv(\"./data/clean/train/*\", sep= \"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+--------------------+--------------------+\n",
      "|     genre|           sentence1|           sentence2|\n",
      "+----------+--------------------+--------------------+\n",
      "| telephone|yeah pay fee amer...|american express ...|\n",
      "|government|now will embark r...|people reviewing ...|\n",
      "| telephone|umhum okay think ...|oak tree leaves p...|\n",
      "|government|fiscal year 2002 ...|beginning fiscal ...|\n",
      "| telephone|well argument one...|statistically con...|\n",
      "+----------+--------------------+--------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_rdd = df.rdd.map(lambda x: (x.genre, \"{} {}\".format(x.sentence1, x.sentence2))).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_agg = spark.createDataFrame(temp_rdd, schema= StructType([\n",
    "    StructField(\"genre\", StringType()),\n",
    "    StructField(\"sentence\", StringType())\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def partition_sentence2vector(partition):\n",
    "    embed = hub.Module(\"./data/pretrained/\")\n",
    "    #rev_text_list = sentence.split(\" \")\n",
    "    generes = []\n",
    "    sentences = []\n",
    "    for item in partition:\n",
    "        generes.append(item[0])\n",
    "        sentences.append(item[1])\n",
    "    #print(generes)\n",
    "    with tf.Session() as session:\n",
    "        session.run([tf.global_variables_initializer(), tf.tables_initializer()])\n",
    "        message_embeddings = session.run(embed(sentences))\n",
    "    # return  [(generes[i], message_embeddings[i]) for i in range(message_embeddings.shape[0])]\n",
    "    return [(genre,vector) for genre,vector in zip(generes,message_embeddings)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_rdd = df_agg.rdd.map(lambda x: (x.genre,x.sentence)).mapPartitions(partition_sentence2vector).map(lambda x: (x[0], Vectors.dense(x[1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "df_new = temp_rdd.toDF( schema= StructType([\n",
    "    StructField(\"genre\", StringType()),\n",
    "    StructField(\"sentenceVec\", VectorUDT())\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "kmeans = KMeans(featuresCol=\"sentenceVec\", predictionCol=\"predCluster\", k= 5, seed= 1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "model = kmeans.fit(df_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_result = model.transform(df_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+--------------------+-----------+\n",
      "|     genre|         sentenceVec|predCluster|\n",
      "+----------+--------------------+-----------+\n",
      "| telephone|[0.04928012937307...|          2|\n",
      "|government|[-0.0333102568984...|          2|\n",
      "| telephone|[0.03496768698096...|          0|\n",
      "|government|[0.01194196287542...|          2|\n",
      "| telephone|[0.03472994640469...|          2|\n",
      "+----------+--------------------+-----------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df_result.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+----------+\n",
      "|predCluster|     genre|\n",
      "+-----------+----------+\n",
      "|          0|    travel|\n",
      "|          1| telephone|\n",
      "|          2|government|\n",
      "|          3|     slate|\n",
      "|          4|   fiction|\n",
      "+-----------+----------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df_temp = df_result.groupBy([\"predCluster\", \"genre\"]).count().groupBy([\"predCluster\"]).agg(F.max(\"count\").alias(\"count\"))\n",
    "df_temp2 = df_result.groupBy([\"predCluster\", \"genre\"]).count().drop(\"predCluster\")\n",
    "df_join = df_temp.join(df_temp2, on=\"count\", how=\"inner\")\n",
    "df_cluster2genre = df_join.select([\"predCluster\", \"genre\"]).sort(\"predCluster\")\n",
    "df_cluster2genre.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* ~~由上面的第二个表可以看出, 在考虑每个簇被分为哪类的时候, 按照0,2,1,3,4的顺序进行考虑~~\n",
    "* 由上面的表可以看出:\n",
    "    * cluster0 认为是`travel`\n",
    "    * cluster1 认为是`telephone`\n",
    "    * cluster2 认为是`government`\n",
    "    * cluster3 认为是`slate`\n",
    "    * cluster4 认为是`fiction`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "cluster2genre = {}\n",
    "for cluster_id, genre_name in df_cluster2genre.select([\"predCluster\", \"genre\"]).rdd.map(lambda x: (x.predCluster, x.genre)).take(5):\n",
    "    cluster2genre[cluster_id] = genre_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_cluster2genre(x):\n",
    "    return cluster2genre[x]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "udf_map_cluster2genre = F.udf(map_cluster2genre, StringType())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_result_2 = df_result.withColumn(\"predGenre\", udf_map_cluster2genre(df_result.predCluster))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_labels = []\n",
    "pred_labels = []\n",
    "for row in df_result_2.select([\"genre\", \"predGenre\"]).take(df_result_2.count()):\n",
    "    true_labels.append(row.genre)\n",
    "    pred_labels.append(row.predGenre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_result_3 = df_result_2.groupBy([\"genre\", \"predGenre\"]).count().sort([\"genre\", \"predGenre\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+-----+\n",
      "|     genre|count|\n",
      "+----------+-----+\n",
      "|    travel|77350|\n",
      "|     slate|77305|\n",
      "|   fiction|77272|\n",
      "|government|77350|\n",
      "| telephone|83345|\n",
      "+----------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy(\"genre\").count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "genre2total = {\n",
    "    \"travel\": 77350,\n",
    "    \"slate\": 77305,\n",
    "    \"fiction\":77272,\n",
    " \"government\":77350,\n",
    " \"telephone\": 83345\n",
    "}\n",
    "def get_percentage(genre, count):\n",
    "#     return \"{}\".format(count)\n",
    "    return \"{:.2f}%\".format((float(count) / genre2total[genre] * 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "udf_get_ratio = F.udf(get_percentage, StringType())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+----------+-----+-------+\n",
      "|     genre| predGenre|count|percent|\n",
      "+----------+----------+-----+-------+\n",
      "|   fiction|   fiction|48396| 62.63%|\n",
      "|   fiction|government| 1719|  2.22%|\n",
      "|   fiction|     slate| 7047|  9.12%|\n",
      "|   fiction| telephone|16068| 20.79%|\n",
      "|   fiction|    travel| 4042|  5.23%|\n",
      "|government|   fiction| 2356|  3.05%|\n",
      "|government|government|65317| 84.44%|\n",
      "|government|     slate| 7481|  9.67%|\n",
      "|government| telephone| 1059|  1.37%|\n",
      "|government|    travel| 1137|  1.47%|\n",
      "|     slate|   fiction|17865| 23.11%|\n",
      "|     slate|government|12128| 15.69%|\n",
      "|     slate|     slate|38691| 50.05%|\n",
      "|     slate| telephone| 5015|  6.49%|\n",
      "|     slate|    travel| 3606|  4.66%|\n",
      "| telephone|   fiction|43330| 51.99%|\n",
      "| telephone|government| 8641| 10.37%|\n",
      "| telephone|     slate|10986| 13.18%|\n",
      "| telephone| telephone|16629| 19.95%|\n",
      "| telephone|    travel| 3759|  4.51%|\n",
      "+----------+----------+-----+-------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "confusion_table = df_result_3.withColumn(\"percent\", udf_get_ratio(df_result_3[\"genre\"], df_result_3[\"count\"]))\n",
    "confusion_table.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = confusion_table.select([\"genre\", \"predGenre\", \"percent\"]).take(confusion_table.count())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_data = {}\n",
    "for row in result:\n",
    "    temp = result_data.get(row.genre, {})\n",
    "    temp[row.predGenre] = row.percent\n",
    "    result_data[row.genre] = temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "genres = list(genre2total.keys())\n",
    "pd_data = [{} for i in range(len(genres))]\n",
    "for genre1 in genres:\n",
    "    for index, genre2 in enumerate(genres):\n",
    "        pd_data[index][genre1] = result_data[genre1].get(genre2, \"0.00%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "            travel   slate fiction government telephone\n",
      "travel      80.23%   4.66%   5.23%      1.47%     4.51%\n",
      "slate        9.91%  50.05%   9.12%      9.67%    13.18%\n",
      "fiction      7.16%  23.11%  62.63%      3.05%    51.99%\n",
      "government   2.20%  15.69%   2.22%     84.44%    10.37%\n",
      "telephone    0.51%   6.49%  20.79%      1.37%    19.95%\n"
     ]
    }
   ],
   "source": [
    "print(pd.DataFrame(pd_data, index= genres))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
